from six.moves import urllib
# Install a process-wide urllib opener that sends a browser-like User-agent
# header. This runs at import time and affects every later
# urllib.request.urlopen call in the process.
# NOTE(review): presumably this exists so torchvision dataset downloads
# (MNIST / FashionMNIST / CIFAR10 are imported below) are not rejected by
# servers that block urllib's default agent — confirm.
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

import torch
import numpy as np
import torchvision
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision import transforms
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torch.utils.data import Dataset, Subset
from torch.nn.modules.loss import TripletMarginLoss
from itertools import permutations
from datetime import datetime
import time
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

def to_device(data, device):
    """Recursively move a tensor, or a (nested) list/tuple of tensors, to ``device``.

    Lists and tuples come back as lists; every leaf is transferred with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        # Leaf: anything exposing a torch-style ``.to`` method.
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader:
    """Iterate over a data loader, transferring each batch with ``to_device``.

    Args:
        data: the wrapped iterable of batches (typically a ``DataLoader``).
        device: target device handed to ``to_device``.
        yield_labels: when True every batch is unpacked as ``(inputs, labels)``;
            only ``inputs`` is moved and ``labels`` are yielded untouched.
            When False the whole batch is handed to ``to_device``.
    """

    def __init__(self, data, device, yield_labels=True):
        self.data, self.device, self.yield_labels = data, device, yield_labels

    def __iter__(self):
        """Yield batches from the wrapped loader after moving them to the device."""
        if not self.yield_labels:
            for batch in self.data:
                yield to_device(batch, self.device)
            return
        for inputs, labels in self.data:
            yield to_device(inputs, self.device), labels

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.data)

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an additive shortcut.

    The shortcut is the identity unless the block changes resolution
    (``stride != 1``) or channel count. In that case it is a strided 1x1
    convolution followed by batch norm that projects the input to the
    block's output shape.
    """

    expansion = 1  # output channels = expansion * planes

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        # An empty Sequential acts as the identity.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + self.shortcut(x))

class Bottleneck(nn.Module):
    """Residual block of 1x1 (reduce), 3x3 (strided) and 1x1 (expand) convolutions.

    The final convolution widens the output to ``expansion * planes``
    channels. When the stride or the channel count changes, a strided 1x1
    projection (conv + batch norm) replaces the identity shortcut.
    """

    expansion = 4  # output channels = expansion * planes

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        # An empty Sequential acts as the identity.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply the three conv/bn stages, add the shortcut, and apply a final ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        residual = self.bn3(self.conv3(hidden))
        return F.relu(residual + self.shortcut(x))

class ResNet(nn.Module):
    """ResNet backbone followed by a linear embedding head.

    Architecture:
        1. A 3x3 stem convolution (64 channels) with batch norm and ReLU.
        2. Four residual stages built from ``block``, with 64, 128, 256 and
           512 planes and strides 1, 2, 2, 2.
        3. A 4x4 average pool.
        4. A linear layer mapping the pooled features to ``output_size`` values.

    Args:
        block: residual block class; it must define an integer ``expansion``
            attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        in_channels: number of channels of the input images.
        output_size: dimensionality of the embedding produced by the head.
        pow: norm degree used by :meth:`p_dist` and :meth:`p_dist_sum`.
        encoder: ``'No'`` makes :meth:`forward` return the embedding together
            with the four stage outputs; ``'Yes'`` makes it return only the
            last stage's feature map.
    """

    def __init__(self, block, num_blocks, in_channels=3, output_size=2, pow=2.0, encoder = 'No'):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.encoder = encoder
        self._pow = pow  # norm degree for p_dist / p_dist_sum
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, output_size)

    # Distance / dissimilarity helpers. These are regular methods rather than
    # lambdas stored on the instance, so the module can be pickled
    # (e.g. ``torch.save(model)``). All assume 2-D ``(N, D)`` inputs.

    def p_dist(self, a, b):
        """Mean ``pow``-norm distance between corresponding rows of ``a`` and ``b``."""
        return torch.mean(F.pairwise_distance(a, b, p=self._pow, eps=0))

    def p_dist_sum(self, a, b):
        """Summed ``pow``-norm distance between corresponding rows of ``a`` and ``b``."""
        return torch.sum(F.pairwise_distance(a, b, p=self._pow, eps=0))

    def cos_sim(self, a, b):
        """One minus the mean cosine similarity of corresponding rows (a dissimilarity)."""
        return 1.0 - torch.mean(F.cosine_similarity(a, b, dim=1, eps=1e-08))

    def gauss_sim(self, a, b):
        """Sum over rows of ``1 - exp(-||a - b||_2^2 / (2 * 0.125))``."""
        dist = F.pairwise_distance(a, b, p=2.0, eps=0)
        return torch.sum(1.0 - torch.exp(-(dist ** 2) / (2.0 * 0.125)))

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one residual stage.

        Only the first block of the stage uses ``stride``; the rest use
        stride 1. Advances ``self.in_planes`` to the stage's output width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch ``x``.

        Returns:
            If ``encoder == 'Yes'``, the layer4 feature map.
            If ``encoder == 'No'``, the tuple
            ``(embedded, out_layer1, out_layer2, out_layer3, out_layer4)``.

        Raises:
            ValueError: if ``self.encoder`` is neither ``'Yes'`` nor ``'No'``.
        """
        if self.encoder not in ('Yes', 'No'):
            raise ValueError(f"encoder must be 'Yes' or 'No', got {self.encoder!r}")
        out = F.relu(self.bn1(self.conv1(x)))
        out_layer1 = self.layer1(out)
        out_layer2 = self.layer2(out_layer1)
        out_layer3 = self.layer3(out_layer2)
        out_layer4 = self.layer4(out_layer3)
        if self.encoder == 'Yes':
            # Encoder mode only needs the feature map, so skip the head entirely.
            return out_layer4
        # The linear head expects 512 * block.expansion features, so the
        # pooled map must be 1x1. NOTE(review): that holds for a 4x4 layer4
        # output (e.g. 32x32 inputs); confirm for other input sizes.
        out = F.avg_pool2d(out_layer4, 4)
        out = out.view(out.size(0), -1)
        embedded = self.linear(out)
        return embedded, out_layer1, out_layer2, out_layer3, out_layer4

    @staticmethod
    def _head_output(output):
        """Pick the embedding out of ``forward()``'s result.

        Uses the first item when the result is a tuple (encoder 'No');
        otherwise returns the result itself.
        """
        return output[0] if isinstance(output, tuple) else output

    def training_step(self, batch):
        """Return the mean ``pow``-distance between predicted and target embeddings.

        ``batch`` is unpacked as ``(images, target_embeddings, _)``.
        """
        images, embeddings, _ = batch
        out = self._head_output(self(images))
        return self.p_dist(out, embeddings)

    @staticmethod
    def _concat(chunks):
        """Concatenate tensor chunks along dim 0; no chunks gives ``torch.tensor([])``."""
        return torch.cat(chunks, 0) if chunks else torch.tensor([])

    def _collect_outputs(self, data_loader, keep_all):
        """Forward every batch under ``no_grad`` and gather CPU copies of the outputs.

        Returns five lists of chunks, in the order embedding, layer1, layer2,
        layer3, layer4. When ``keep_all`` is False only the layer4 list is
        filled. Device tensors of the last batch are released when this returns.
        """
        chunks = ([], [], [], [], [])
        with torch.no_grad():
            for data, _ in data_loader:
                outputs = self.forward(data)
                if keep_all:
                    for store, tensor in zip(chunks, outputs):
                        store.append(tensor.cpu())
                else:
                    chunks[4].append(outputs[4].cpu())
        return chunks

    def extract_feature(self, dataset, mode='train', device=None):
        """Run the network over ``dataset`` in order and collect its outputs on the CPU.

        Args:
            dataset: map-style dataset whose items collate into
                ``(inputs, labels)`` batches; the labels are ignored.
            mode: ``'train'`` switches the model to training mode and returns
                every output; ``'test'`` switches it to eval mode and returns
                only the layer4 output.
            device: where input batches are sent. Defaults to the device of
                the model's parameters.

        Returns:
            For ``'train'``, ``(embedding, layer1, layer2, layer3, layer4)``,
            where ``embedding`` is a NumPy array and the stage outputs are
            CPU tensors. For ``'test'``, the layer4 output as a CPU tensor.

        Raises:
            ValueError: if ``mode`` is not ``'train'``/``'test'``, or the model
                was built with ``encoder != 'No'`` (all five outputs are needed).
        """
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        if self.encoder != 'No':
            raise ValueError("extract_feature requires a model built with encoder='No'")
        if device is None:
            device = next(self.parameters()).device

        # About 30 batches over the dataset (change this if not enough memory);
        # at least 1 so datasets smaller than 30 items still load.
        size_batch = max(1, len(dataset) // 30)
        data_loader = DeviceDataLoader(
            DataLoader(dataset=dataset, batch_size=size_batch, shuffle=False), device)

        keep_all = mode == 'train'
        if keep_all:
            # NOTE: in train mode BatchNorm normalises with batch statistics and
            # updates its running statistics, even under no_grad.
            self.train()
        else:
            self.eval()
        chunks = self._collect_outputs(data_loader, keep_all)
        torch.cuda.empty_cache()  # return the freed cached device memory

        embedding, layer1, layer2, layer3, layer4 = (self._concat(c) for c in chunks)
        if keep_all:
            return embedding.numpy(), layer1, layer2, layer3, layer4
        return layer4

    def evaluate_step(self, val_loader):
        """Return ``{'val_loss': mean p_dist loss}`` over ``val_loader``, in eval mode.

        Each batch is unpacked as ``(images, target_embeddings, labels)``;
        the labels are unused. The model is left in eval mode.
        """
        self.eval()
        val_loss = []
        with torch.no_grad():
            for images, embeddings, _ in val_loader:
                out_embedding = self._head_output(self(images))
                val_loss.append(self.p_dist(out_embedding, embeddings).item())

        epoch_loss = torch.tensor(val_loss).mean()  # Combine losses
        return {'val_loss': epoch_loss.item()}

def ResNet_size(size=10, in_channels=3, output_size=2, encoder = 'No'):
    """Build a ResNet of a standard depth.

    Args:
        size: network depth. 10, 18 and 34 use ``BasicBlock``; 50, 101 and
            152 use ``Bottleneck``.
        in_channels: number of input image channels.
        output_size: dimensionality of the embedding head.
        encoder: forwarded to :class:`ResNet` (``'No'`` or ``'Yes'``).

    Returns:
        A freshly initialised :class:`ResNet`.

    Raises:
        ValueError: if ``size`` is not one of the supported depths (previously
            ``None`` was returned silently).
    """
    basic_layouts = {10: [1, 1, 1, 1], 18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}
    bottleneck_layouts = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
    if size in basic_layouts:
        block, num_blocks = BasicBlock, basic_layouts[size]
    elif size in bottleneck_layouts:
        block, num_blocks = Bottleneck, bottleneck_layouts[size]
    else:
        supported = sorted([*basic_layouts, *bottleneck_layouts])
        raise ValueError(f"unsupported ResNet size {size!r}; expected one of {supported}")
    return ResNet(block, num_blocks, in_channels=in_channels, output_size=output_size, encoder=encoder)

